import itertools
from copy import deepcopy
from enum import Enum

import numpy as np


# A typesafe enumeration of the types of endpoints that are permitted in
# Tetrad-style graphs: tail (--), null (-), arrow (->), circle (-o) and star (-*).
# 'TAIL_AND_ARROW' and 'ARROW_AND_ARROW' mean that there are two types of edges
# (<-> and -->) between two nodes.
class Endpoint(Enum):
    """Endpoint marks that may sit on either end of an edge in the graph.

    The integer values are the codes that the helper functions in this
    module compare against entries of the graph's adjacency matrix.
    """

    TAIL = -1             # tail mark (--)
    NULL = 0              # null mark (-); compared against to detect "no adjacency"
    ARROW = 1             # arrowhead (->)
    CIRCLE = 2            # circle mark (-o)
    STAR = 3              # star mark (-*)
    TAIL_AND_ARROW = 4    # two edges of different types between the same pair of nodes
    ARROW_AND_ARROW = 5   # two edges of different types between the same pair of nodes

    def __str__(self):
        """Return the member's name (e.g. ``"ARROW"``) rather than ``Endpoint.ARROW``."""
        return f"{self.name}"


def feval(parameters: list):
    """Evaluate a local score from a packed argument list.

    Args:
        parameters: sequence laid out as
            ``[score_func, Data, node_index, parent_indices, extra]``.
            Only positions 0, 2 and 3 are read here; positions 1 and 4 are
            ignored by this function.

    Returns:
        Whatever ``parameters[0].score(parameters[2], parameters[3])`` returns.
    """
    local_score = parameters[0].score
    return local_score(parameters[2], parameters[3])

def score_g(Data, G, score_func, parameters):
    """Sum the local scores of every node of ``G`` given its parents.

    Args:
        Data: forwarded untouched to ``feval``.
        G: graph (the original comment says a DAG) exposing ``get_nodes()``,
            ``get_parents(node)`` and a ``node_map`` mapping node -> index.
        score_func: scorer object forwarded to ``feval``.
        parameters: extra scorer parameters forwarded to ``feval``.

    Returns:
        The accumulated score, starting from 0.
    """
    total = 0
    # The enumeration position is used as the node's index when scoring.
    for position, node in enumerate(G.get_nodes()):
        parent_indices = [G.node_map[parent] for parent in G.get_parents(node)]
        local = feval([score_func, Data, position, parent_indices, parameters])
        total = total + local
    return total


def Combinatorial(T0):
    """Return every subset of ``T0``, smallest subsets first.

    Args:
        T0: a sized iterable (list, tuple, 1-D array, ...) of elements.

    Returns:
        A list of tuples: the empty tuple, then all 1-element subsets, then
        all 2-element subsets, and so on, each group in the order produced
        by ``itertools.combinations``.

    Every subset is a fresh tuple, including when ``T0`` has exactly one
    element, so the result never aliases the caller's ``T0``.
    """
    subsets_by_size = (itertools.combinations(T0, size) for size in range(len(T0) + 1))
    return list(itertools.chain.from_iterable(subsets_by_size))

def insert_validity_test1(G, i, j, T) -> int:
    """First validity condition of the Insert operator.

    Checks whether NA ∪ T forms a clique, where NA is the set of nodes
    that are neighbours of Xj and adjacent to Xi.

    Args:
        G: graph (the original comment says a CPDAG) exposing the adjacency
            matrix as ``G.graph``.
        i: index of node Xi.
        j: index of node Xj.
        T: collection of node indices to add to NA.

    Returns:
        The result of ``check_clique`` (1 means valid, 0 means invalid).
    """
    adjacency = G.graph
    tail = Endpoint.TAIL.value
    null = Endpoint.NULL.value

    # Neighbours of Xj: TAIL marks in both column j and row j.
    neighbours_of_j = np.intersect1d(np.where(adjacency[:, j] == tail)[0],
                                     np.where(adjacency[j, :] == tail)[0])
    # Nodes adjacent to Xi: any non-NULL mark in column i or row i.
    adjacent_to_i = np.union1d(np.where(adjacency[:, i] != null)[0],
                               np.where(adjacency[i, :] != null)[0])
    shared = np.intersect1d(neighbours_of_j, adjacent_to_i)

    candidate = list(np.union1d(shared, T).astype(int))
    return check_clique(G, candidate)

def check_clique(G, subnode) -> int:
    """Tell whether the nodes in ``subnode`` form a clique in ``G``.

    Edge direction is ignored: every ARROW mark in the induced subgraph is
    replaced (together with its mirrored entry) by a TAIL mark before the
    comparison.

    Args:
        G: graph (the original comment says a CPDAG) exposing the adjacency
            matrix as ``G.graph``.
        subnode: list of node indices to test.

    Returns:
        1 if the induced subgraph is complete (or ``subnode`` is empty),
        otherwise 0.
    """
    induced = deepcopy(G.graph[np.ix_(subnode, subnode)])  # private copy of the subgraph
    size = len(subnode)
    if size == 0:
        return 1

    tail = Endpoint.TAIL.value
    arrow_rows, arrow_cols = np.where(induced == Endpoint.ARROW.value)
    induced[arrow_rows, arrow_cols] = tail
    induced[arrow_cols, arrow_rows] = tail

    # A complete subgraph has 0 on the diagonal and -1 everywhere else.
    complete = np.eye(size) - np.ones((size, size))
    return 1 if np.all(complete == induced) else 0

def insert_vc2_new(G, j, i, NAT):
    """Second validity condition of the Insert operator (depth-first search).

    Starting from ``j``, the search follows, from each visited node ``v``,
    every node ``k`` with ``G.graph[k, v] == ARROW`` plus every node joined
    to ``v`` by TAIL marks in both directions, never revisiting a node that
    is already on the current path. Each time ``i`` is reached, the nodes
    that precede it on that path (including ``j``) are checked against
    ``NAT``.

    Args:
        G: graph (the original comment says a CPDAG) exposing the adjacency
            matrix as ``G.graph``.
        j: index of the start node.
        i: index of the target node.
        NAT: collection of node indices; membership is tested with ``in``.

    Returns:
        1 if every path found from ``j`` to ``i`` contains a node in
        ``NAT``, 0 as soon as one path without such a node is found.
    """
    adjacency = G.graph
    arrow = Endpoint.ARROW.value
    tail = Endpoint.TAIL.value

    def path_before(entry):
        # Yield the values of the nodes that precede `entry` on its path,
        # nearest first; the root entry carries an empty 'pa' dict.
        parent = entry['pa']
        while parent:
            yield parent['value']
            parent = parent['pa']

    # LIFO stack with its top at the END of the list, so push and pop are
    # O(1); the visiting order matches a front-of-list stack.
    pending = [{'value': j, 'pa': {}}]
    while pending:
        current = pending.pop()
        node = current['value']

        if node == i:
            # Reached the target: this path must contain a node in NAT.
            if not any(previous in NAT for previous in path_before(current)):
                return 0
            continue

        successors = np.concatenate((np.where(adjacency[:, node] == arrow)[0],
                                     np.intersect1d(np.where(adjacency[node, :] == tail)[0],
                                                    np.where(adjacency[:, node] == tail)[0])))
        for successor in successors:
            # Skip nodes that already appear earlier on this path.
            if all(previous != successor for previous in path_before(current)):
                pending.append({'value': successor, 'pa': current})
    return 1


def find_subset_include(s0, sub):
    """Flag the entries of ``sub`` that contain every element of ``s0``.

    Args:
        s0: the subset to look for.
        sub: sequence of candidate subsets.

    Returns:
        A float numpy array of length ``len(sub)`` holding 1 where the
        candidate includes all of ``s0`` and 0 elsewhere. When ``s0`` or
        ``sub`` is empty, the array is all ones.
    """
    if len(s0) == 0 or len(sub) == 0:
        return np.ones(len(sub))

    required = set(s0)
    return np.array([1.0 if required.issubset(candidate) else 0.0 for candidate in sub])


def insert_changed_score(Data, G, i, j, T, record_local_score, score_func, parameters):
    """Score change produced by the Insert operator i -> j with set ``T``.

    Two parent sets of Xj are scored: NA ∪ T ∪ Pa(j) with and without
    ``i``, where NA holds the neighbours of Xj that are adjacent to Xi.
    Scores already present in ``record_local_score[j]`` are reused;
    missing ones are computed through ``feval`` and appended to that list.

    Args:
        Data: forwarded untouched to ``feval``.
        G: graph exposing the adjacency matrix as ``G.graph``.
        i: index of the node being inserted as a parent.
        j: index of the node receiving the new parent.
        T: collection of node indices added to NA.
        record_local_score: per-node cache; ``record_local_score[j]`` is a
            list of entries ``[idx_1, ..., idx_k, score]``. An empty parent
            set is stored as ``[-1, score]``. Mutated in place.
        score_func: scorer object forwarded to ``feval``.
        parameters: extra scorer parameters forwarded to ``feval``.

    Returns:
        ``(ch_score, desc, record_local_score)`` where ``ch_score`` is the
        score with ``i`` minus the score without it and ``desc`` is
        ``[i, j, T]``.
    """
    adjacency = G.graph
    tail = Endpoint.TAIL.value
    null = Endpoint.NULL.value

    neighbours_of_j = np.intersect1d(np.where(adjacency[:, j] == tail)[0],
                                     np.where(adjacency[j, :] == tail)[0])
    adjacent_to_i = np.union1d(np.where(adjacency[i, :] != null)[0],
                               np.where(adjacency[:, i] != null)[0])
    shared = np.intersect1d(neighbours_of_j, adjacent_to_i)
    parents_of_j = np.where(adjacency[j, :] == Endpoint.ARROW.value)[0]

    # Parent set without i, then the same set with i added.
    without_i = np.union1d(np.union1d(shared, T).astype(int), parents_of_j)
    with_i = np.union1d(without_i, [i]).astype(int)

    cache = record_local_score[j]
    score_with = 0
    score_without = 0
    have_with = False
    have_without = False

    # Look both parent sets up in the cache before computing anything.
    for entry in cache:
        if not np.setxor1d(entry[0:-1], with_i).size:
            score_with = entry[-1]
            have_with = True

        # An empty parent set is stored under the sentinel key [-1].
        if (not np.setxor1d(entry[0:-1], without_i).size
                or ((not np.setxor1d(entry[0:-1], [-1]).size) and (not without_i.size))):
            score_without = entry[-1]
            have_without = True

        if have_with and have_without:
            break

    if not have_with:
        score_with = feval([score_func, Data, j, with_i, parameters])
        cache.append(list(with_i) + [score_with])

    if not have_without:
        score_without = feval([score_func, Data, j, without_i, parameters])
        key = list(without_i) if len(without_i) != 0 else [-1]
        cache.append(key + [score_without])

    return score_with - score_without, [i, j, T], record_local_score
